import json
from collections import OrderedDict
from typing import List, Tuple, Iterable, Optional
import os

import numpy as np
from sklearn.model_selection import train_test_split, StratifiedKFold

from Utils import get_logger, logger, config as cfg
from Utils.Constants import FileNamesConstants
from Utils.utils import get_model_id_from_file_name


def stats_folders_train_test_split(is_weights: bool, folders: Iterable[str], val_size: Optional[float]) -> Tuple[List, List]:
    """
    Gather the stats files (weights stats or gradients stats) found directly inside each folder and split them,
    folder by folder, into a training list and a validation list.
    ** the split is done separately per folder, so every folder contributes to both lists
    :param is_weights: True selects files whose name contains FileNamesConstants.WEIGHTS_STATS,
                       False selects files whose name contains FileNamesConstants.GRADIENTS_STATS
    :param folders: stats folders to scan (joined with the file names to build the returned paths)
    :param val_size: passed to train_test_split as test_size; None or 0 means no validation, in which case the
                     validation list is empty and every file goes to training
    :return: Tuple of 2 lists (train paths, validation paths)
    """
    logger().log('stats_folders_train_test_split', 'val size', val_size, '\nstats loaded from folders:', folders)
    if val_size is None:
        val_size = 0
    if is_weights:
        file_type = FileNamesConstants.WEIGHTS_STATS
    else:
        file_type = FileNamesConstants.GRADIENTS_STATS

    # Phase 1: list every folder; folders without any matching file are dropped.
    # NOTE(review): os.listdir returns names in arbitrary order, so the seeded split below is only
    # repeatable as long as the directory listing order itself stays the same.
    files_per_folder = {}
    for stats_dir in folders:
        matching = [entry for entry in os.listdir(stats_dir) if file_type in entry]
        if matching:
            files_per_folder[stats_dir] = matching

    # Phase 2: split each folder's files on its own and accumulate the joined paths.
    train_data = []
    val_data = []
    for stats_dir, names in files_per_folder.items():
        if val_size == 0:
            train_names, val_names = names, []
        else:
            train_names, val_names = train_test_split(names, test_size=val_size, random_state=cfg.seed)

        train_data.extend(os.path.join(stats_dir, name) for name in train_names)
        val_data.extend(os.path.join(stats_dir, name) for name in val_names)

    logger().log('stats_folders_train_test_split', 'train stats files:', train_data)
    logger().log('stats_folders_train_test_split', 'validation stats files:', val_data)
    return train_data, val_data


def test_train_test_split():
    """
    Manual smoke run of stats_folders_train_test_split: weights stats, two hard-coded stats folders,
    20% of each folder held out for validation. The result is not checked, only logged by the callee.
    """
    root_dir = '/sise/group/'
    relative_stats_dirs = ('models_0_115/stats', 'models_500_650/stats')
    stats_dirs = [os.path.join(root_dir, rel_dir) for rel_dir in relative_stats_dirs]
    stats_folders_train_test_split(True, stats_dirs, 0.2)


def collect_models_results(data_files: Iterable[str], results_folders: Iterable[str], metric_name: str) -> Tuple[List[str], OrderedDict]:
    """
    Find results for all files.

    Every file whose name contains FileNamesConstants.RESULTS is read from each results folder. Two layouts are
    handled:
      * file name contains both 'new_format' and 'json': the whole file is one JSON object mapping
        model id -> results dict
      * any other file: one JSON document per line; lines that decode to a plain string are skipped, the others
        must carry a 'model_id' key
    A warning is logged when a model id from a line-based file is already present in the map.

    :param data_files: list of files with model id to find their result
    :param results_folders: list of folders with results_ files for finding the result
    :param metric_name: metric key in results file
    :return: Tuple with all models ids (same order as data_files) and a map from model id to result value.
             Models for which no result was found are simply missing from the map.
    :raises KeyError: if a matching entry has no metric_name key, or a line-based entry has no 'model_id' key
    """
    # NOTE(review): folders are visited in descending sorted order and a later assignment overwrites an earlier
    # one, so when a model appears in several folders the value read LAST (lowest-sorting folder) is the one
    # kept. The original comment here said the sort is meant "to keep results from the higher models idx" -
    # confirm which of the two is intended.
    results_folders = sorted(results_folders, reverse=True)
    models_ids = [get_model_id_from_file_name(curr_file) for curr_file in data_files]
    # Membership is tested once per result entry; a set makes that O(1) instead of scanning the list each time.
    # The list itself is still what is returned, so order and duplicates are preserved for the caller.
    wanted_ids = set(models_ids)
    models_results_map = OrderedDict()
    for curr_folder in results_folders:
        results_files = [name for name in os.listdir(curr_folder) if FileNamesConstants.RESULTS in name]
        for curr_file in results_files:
            is_new_format = 'new_format' in curr_file and 'json' in curr_file
            with open(os.path.join(curr_folder, curr_file), 'r') as jf:
                if is_new_format:
                    for curr_model_id, res in json.load(jf).items():
                        if curr_model_id in wanted_ids:
                            models_results_map[curr_model_id] = res[metric_name]
                    continue

                for line in jf:
                    data = json.loads(line)
                    if isinstance(data, str):
                        # a bare string has no 'model_id' / metric fields to read
                        continue
                    curr_model_id = str(data['model_id'])
                    if curr_model_id not in wanted_ids:
                        continue
                    if curr_model_id in models_results_map:
                        logger().warning('DataUtils::collect_models_results',
                                         f'Model: {curr_model_id} appears twice in results files. '
                                         f'From folder: {curr_folder}, {curr_file}')
                    models_results_map[curr_model_id] = data[metric_name]
    return models_ids, models_results_map


if __name__ == '__main__':
    # Script entry point: set up the logger under this file's base name (text before the first dot),
    # then run the manual split smoke test.
    get_logger(os.path.basename(__file__).partition('.')[0])
    test_train_test_split()


def stratified_continues_k_fold_gen(x_vals: np.array, y_vals: np.array, num_folds: int, num_bins: int = 150):
    """
    Generate stratified k-fold splits for continuous labels.
    The labels are first bucketed into histogram bins; the bin index of each sample is then used as the class
    label for StratifiedKFold (shuffled, seeded with cfg.seed).
    :param x_vals: samples; converted with np.array if not already an ndarray
    :param y_vals: continuous labels, one per sample; converted with np.array if not already an ndarray
    :param num_folds: number of folds (n_splits of StratifiedKFold)
    :param num_bins: number of bins used for splitting the data into different bins
    :return: generator yielding, per fold, the tuple (train x, train y, test x, test y)
    """
    x_vals = x_vals if isinstance(x_vals, np.ndarray) else np.array(x_vals)
    y_vals = y_vals if isinstance(y_vals, np.ndarray) else np.array(y_vals)

    # The last (right-most) edge is dropped before digitizing, so the maximum label lands in the
    # final bin together with its neighbours instead of getting an index of its own.
    edges = np.histogram_bin_edges(y_vals, num_bins)
    y_bins = np.digitize(y_vals, edges[:-1])

    splitter = StratifiedKFold(n_splits=num_folds, random_state=cfg.seed, shuffle=True)
    for train_idx, test_idx in splitter.split(x_vals, y_bins):
        train_x = x_vals[train_idx]
        train_y = y_vals[train_idx]
        test_x = x_vals[test_idx]
        test_y = y_vals[test_idx]
        yield train_x, train_y, test_x, test_y
